Time series clustering partitions time series data into groups based on similarity or distance, so that time series within the same cluster are similar to one another.
Methodology followed:
from vrae.vrae import VRAE
from vrae.utils import *
from vrae.utils_EMG import *
import numpy as np
import torch
import pickle
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.metrics import mean_squared_error as mse
import plotly
from torch.utils.data import DataLoader, TensorDataset
plotly.offline.init_notebook_mode()
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
# --- Model and training hyperparameters ---
dload = './model_dir'  # directory for saving/loading model checkpoints
seq_len = 10  # time steps per training window
hidden_size = 256  # RNN hidden-state size
hidden_layer_depth = 3  # number of stacked RNN layers
latent_length = 16  # dimensionality of the latent vector z
batch_size = 32
learning_rate = 0.00002
n_epochs = 1500
dropout_rate = 0.0  # no dropout between recurrent layers
optimizer = 'Adam' # options: 'Adam', 'SGD'
cuda = True # options: True, False — train on GPU if available
print_every=10  # log average loss every N epochs
clip = True # options: True, False — enable gradient-norm clipping
max_grad_norm=5  # clipping threshold used when clip is True
loss = 'MSELoss' # options: SmoothL1Loss, MSELoss — reconstruction loss
block = 'LSTM' # options: LSTM, GRU — recurrent cell type
output = False
# Recording sessions used for training; load_data windows each file into
# seq_len-step sequences (see vrae.utils_EMG).
training_file = ['20201020_Pop_Cage_001','20201020_Pop_Cage_002','20201020_Pop_Cage_003','20201020_Pop_Cage_004',
'20201020_Pop_Cage_006']
# X_train: (n_windows, seq_len, n_channels) array, y_train: (n_windows, 1) labels
# — shapes per the printed output below. do_pca=False keeps raw channels, so
# pca_component is presumably unused here — TODO confirm in load_data.
X_train, y_train = load_data(direc = 'data', dataset="EMG", all_file = training_file,
do_pca = False, single_channel = None,
batch_size = batch_size, seq_len = seq_len, pca_component = 6)
# Unsupervised training: only X is wrapped; labels are kept for visualization.
train_dataset = TensorDataset(torch.from_numpy(X_train))
Loading 20201020_Pop_Cage_001, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3.] Loading 20201020_Pop_Cage_002, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3.] Loading 20201020_Pop_Cage_003, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.] Loading 20201020_Pop_Cage_004, X shape (3601, 150, 1), y shape (3601, 1), has label [-1. 0. 1. 2. 3. 4.] Loading 20201020_Pop_Cage_006, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.] Dataset shape: (17984, 10, 15) Label: [-1. 0. 1. 2. 3. 4.], shape: (17984, 1)
# Number of input channels per time step (last axis of the windowed data).
num_features = X_train.shape[2]
VRAE inherits from sklearn.base.BaseEstimator and overrides the fit, transform, and fit_transform methods, so it can be used like other sklearn estimators.
# Redundant with the import at the top of the file, but harmless.
from vrae.vrae import VRAE
# Instantiate the variational recurrent autoencoder with the hyperparameters
# defined above; VRAE follows the sklearn estimator API (fit/transform).
vrae = VRAE(sequence_length=seq_len,
number_of_features = num_features,
hidden_size = hidden_size,
hidden_layer_depth = hidden_layer_depth,
latent_length = latent_length,
batch_size = batch_size,
learning_rate = learning_rate,
n_epochs = n_epochs,
dropout_rate = dropout_rate,
optimizer = optimizer,
cuda = cuda,
print_every=print_every,
clip=clip,
max_grad_norm=max_grad_norm,
loss = loss,
block = block,
dload = dload,
output = output)
#vrae.fit(train_dataset)  # duplicate of the call below, left commented out
#If the model has to be saved with the learnt parameters, use vrae.save() after fitting.
vrae.fit(train_dataset)
Epoch: 9 Average loss: 5891727.6573 Epoch: 19 Average loss: 4115730.8946 Epoch: 29 Average loss: 2934408.1354 Epoch: 39 Average loss: 2171484.7346 Epoch: 49 Average loss: 1688999.0603 Epoch: 59 Average loss: 1379960.1896 Epoch: 69 Average loss: 1177920.1288 Epoch: 79 Average loss: 1035423.7469 Epoch: 89 Average loss: 921003.2425 Epoch: 99 Average loss: 823723.0066 Epoch: 109 Average loss: 751811.7032 Epoch: 119 Average loss: 699223.2283 Epoch: 129 Average loss: 660168.7892 Epoch: 139 Average loss: 628558.9083 Epoch: 149 Average loss: 600079.6541 Epoch: 159 Average loss: 575455.3235 Epoch: 169 Average loss: 554367.3257 Epoch: 179 Average loss: 535073.4193 Epoch: 189 Average loss: 517000.6199 Epoch: 199 Average loss: 501761.7260 Epoch: 209 Average loss: 488896.7342 Epoch: 219 Average loss: 477162.7586 Epoch: 229 Average loss: 466896.5268 Epoch: 239 Average loss: 457814.0987 Epoch: 249 Average loss: 449363.7805 Epoch: 259 Average loss: 441471.1488 Epoch: 269 Average loss: 434257.0092 Epoch: 279 Average loss: 427321.2723 Epoch: 289 Average loss: 420805.0458 Epoch: 299 Average loss: 414602.4329 Epoch: 309 Average loss: 408619.3156 Epoch: 319 Average loss: 402770.9266 Epoch: 329 Average loss: 397269.9666 Epoch: 339 Average loss: 391942.1424 Epoch: 349 Average loss: 386599.0610 Epoch: 359 Average loss: 381440.0871 Epoch: 369 Average loss: 376490.8925 Epoch: 379 Average loss: 371549.9355 Epoch: 389 Average loss: 367174.0705 Epoch: 399 Average loss: 362425.5994 Epoch: 409 Average loss: 358037.4904 Epoch: 419 Average loss: 353642.3020 Epoch: 429 Average loss: 349512.5671 Epoch: 439 Average loss: 345479.9834 Epoch: 449 Average loss: 341614.6800 Epoch: 459 Average loss: 337528.9933 Epoch: 469 Average loss: 333843.5924 Epoch: 479 Average loss: 330223.8180 Epoch: 489 Average loss: 326884.7638 Epoch: 499 Average loss: 323185.5307 Epoch: 509 Average loss: 319776.3792 Epoch: 519 Average loss: 316500.8656 Epoch: 529 Average loss: 313308.3086 Epoch: 539 Average loss: 310155.8990 
Epoch: 549 Average loss: 307100.1049 Epoch: 559 Average loss: 304033.9319 Epoch: 569 Average loss: 301311.6086 Epoch: 579 Average loss: 298256.1927 Epoch: 589 Average loss: 295631.7679 Epoch: 599 Average loss: 292857.2085 Epoch: 609 Average loss: 290344.8730 Epoch: 619 Average loss: 287741.0361 Epoch: 629 Average loss: 285105.4003 Epoch: 639 Average loss: 282657.9251 Epoch: 649 Average loss: 280240.3975 Epoch: 659 Average loss: 278019.9135 Epoch: 669 Average loss: 275609.9852 Epoch: 679 Average loss: 273252.8737 Epoch: 689 Average loss: 271166.8934 Epoch: 699 Average loss: 268942.5618 Epoch: 709 Average loss: 266854.8537 Epoch: 719 Average loss: 264564.6353 Epoch: 729 Average loss: 262639.8959 Epoch: 739 Average loss: 260465.2851 Epoch: 749 Average loss: 258618.8753 Epoch: 759 Average loss: 256636.7499 Epoch: 769 Average loss: 254866.4674 Epoch: 779 Average loss: 252964.1101 Epoch: 789 Average loss: 251018.5461 Epoch: 799 Average loss: 249240.3056 Epoch: 809 Average loss: 247545.9195 Epoch: 819 Average loss: 245734.2310 Epoch: 829 Average loss: 244093.0183 Epoch: 839 Average loss: 242424.9341 Epoch: 849 Average loss: 240826.7157 Epoch: 859 Average loss: 239157.0223 Epoch: 869 Average loss: 237567.9043 Epoch: 879 Average loss: 235863.3427 Epoch: 889 Average loss: 234406.6157 Epoch: 899 Average loss: 232878.4159 Epoch: 909 Average loss: 231531.5166 Epoch: 919 Average loss: 229938.2736 Epoch: 929 Average loss: 228655.6076 Epoch: 939 Average loss: 227071.2267 Epoch: 949 Average loss: 225715.2936 Epoch: 959 Average loss: 224197.8663 Epoch: 969 Average loss: 222919.4361 Epoch: 979 Average loss: 221430.4985 Epoch: 989 Average loss: 220213.0969 Epoch: 999 Average loss: 218988.6693 Epoch: 1009 Average loss: 217603.1820 Epoch: 1019 Average loss: 216310.6568 Epoch: 1029 Average loss: 215119.0043 Epoch: 1039 Average loss: 213855.2409 Epoch: 1049 Average loss: 212701.7637 Epoch: 1059 Average loss: 211558.0679 Epoch: 1069 Average loss: 210379.6927 Epoch: 1079 Average loss: 
209151.8073 Epoch: 1089 Average loss: 207994.3229 Epoch: 1099 Average loss: 206961.9330 Epoch: 1109 Average loss: 205627.1286 Epoch: 1119 Average loss: 204686.1063 Epoch: 1129 Average loss: 203561.6146 Epoch: 1139 Average loss: 202419.8964 Epoch: 1149 Average loss: 201204.2227 Epoch: 1159 Average loss: 200277.3083 Epoch: 1169 Average loss: 199250.2502 Epoch: 1179 Average loss: 198060.8015 Epoch: 1189 Average loss: 197138.0566 Epoch: 1199 Average loss: 196203.8653 Epoch: 1209 Average loss: 195259.4156 Epoch: 1219 Average loss: 194127.2683 Epoch: 1229 Average loss: 193080.3973 Epoch: 1239 Average loss: 192340.7876 Epoch: 1249 Average loss: 191469.0995 Epoch: 1259 Average loss: 190430.9854 Epoch: 1269 Average loss: 189498.3008 Epoch: 1279 Average loss: 188411.7199 Epoch: 1289 Average loss: 187802.8450 Epoch: 1299 Average loss: 186643.5880 Epoch: 1309 Average loss: 185876.3934 Epoch: 1319 Average loss: 184962.4783 Epoch: 1329 Average loss: 184109.4077 Epoch: 1339 Average loss: 183104.1834 Epoch: 1349 Average loss: 182360.9208 Epoch: 1359 Average loss: 181450.0596 Epoch: 1369 Average loss: 180673.2927 Epoch: 1379 Average loss: 179717.8998 Epoch: 1389 Average loss: 178847.6470 Epoch: 1399 Average loss: 178302.4848 Epoch: 1409 Average loss: 177338.9082 Epoch: 1419 Average loss: 176589.0585 Epoch: 1429 Average loss: 175780.3559 Epoch: 1439 Average loss: 174929.9742 Epoch: 1449 Average loss: 174110.5722 Epoch: 1459 Average loss: 173303.1827 Epoch: 1469 Average loss: 172645.1320 Epoch: 1479 Average loss: 171976.3129 Epoch: 1489 Average loss: 171007.7986 Epoch: 1499 Average loss: 170450.0694
# Training-loss curve. NOTE(review): `plt` is not imported in this file —
# presumably provided by `from vrae.utils import *`; confirm.
plt.plot(vrae.all_loss)
[<matplotlib.lines.Line2D at 0x7f2f0147cd60>]
# Reconstruction-MSE curve recorded during training — TODO confirm
# `rec_mse` semantics against the VRAE implementation.
plt.plot(vrae.rec_mse)
[<matplotlib.lines.Line2D at 0x7f2f01333d30>]
#If the latent vectors have to be saved, pass the parameter `save`
# Encode every training window into its latent vector; also pickled to disk.
z_run = vrae.transform(train_dataset, save = True, filename = 'z_run_e57_b32_z16.pkl')
# Expected shape: (n_windows, latent_length) — see the printed output below.
z_run.shape
(17984, 16)
# Checkpoint the freshly trained weights.
vrae.save('./vrae_e57_b32_z16.pth')
# NOTE(review): the next two steps replace the just-trained model and its
# latents with an earlier checkpoint/pickle (different filenames) — confirm
# this is the intended notebook flow.
vrae.load(dload+'/vrae_e5_3000epoch.pth')
with open(dload+'/z_run_e57pca_2000epoch.pkl', 'rb') as fh:
    z_run = pickle.load(fh)
# Decode the training windows and inspect per-channel reconstruction quality.
reconstruction = recon(vrae, X_train)
plot_recon_feature(X_train, reconstruction, idx = None)
_, _, _ = plot_recon_metrics(X_train, reconstruction, x_lim = [2000, 4000])
Channel 1, corr = 0.7021, mse = 34.952081, mean = 29.5886. Channel 2, corr = 0.6713, mse = 30.861576, mean = 27.4895. Channel 3, corr = 0.6417, mse = 44.195031, mean = 31.6063. Channel 4, corr = 0.5699, mse = 28.199757, mean = 19.6259. Channel 5, corr = 0.6017, mse = 19.871269, mean = 13.4139. Channel 6, corr = 0.6487, mse = 40.137154, mean = 32.0427. Channel 7, corr = 0.8288, mse = 41.696226, mean = 49.2383. Channel 8, corr = 0.8158, mse = 51.406393, mean = 54.5515. Channel 9, corr = 0.6582, mse = 23.042701, mean = 21.3511. Channel 10, corr = 0.6929, mse = 39.546940, mean = 30.8874. Channel 11, corr = 0.8106, mse = 27.314303, mean = 46.5397. Channel 12, corr = 0.6312, mse = 33.455014, mean = 21.5676. Channel 13, corr = 0.8507, mse = 37.975793, mean = 50.0767. Channel 14, corr = 0.7943, mse = 38.463172, mean = 39.4550. Channel 15, corr = 0.7560, mse = 38.891901, mean = 36.4381.
#recon_channel = pca_inverse(X_pca, reconstruction)
#plot_recon_feature(X_train_ori, recon_channel, idx = None)
#_, _, _ = plot_recon_metrics(X_train_ori, recon_channel, x_lim = [0, 2000])
# Held-out recording sessions used only for evaluation.
testing_file = ['20201020_Pop_Cage_005', '20201020_Pop_Cage_007']
# Same windowing/preprocessing settings as the training data — TODO confirm
# load_data normalizes train and test identically.
X_test, y_test = load_data(direc = 'data', dataset="EMG", all_file = testing_file,
do_pca = False, single_channel = None,
batch_size = batch_size, seq_len = seq_len, pca_component = 6)
Loading 20201020_Pop_Cage_005, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 5.] Loading 20201020_Pop_Cage_007, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.] Dataset shape: (7168, 10, 15) Label: [-1. 0. 1. 2. 3. 4. 5.], shape: (7168, 1)
# Uncomment if using pca
# Reconstruct the held-out sessions with the currently loaded model.
recon_test = recon(vrae, X_test)
# recon_channel_test = pca_inverse(test_pca, recon_test)
plot_recon_feature(X_test, recon_test, idx = None)
# plot_recon_feature(X_test_ori, recon_channel_test, idx = None)
# Per-channel correlation / MSE / signal mean between input and reconstruction.
corr_mean, mse_mean, mean_mean = plot_recon_metrics(X_test, recon_test, x_lim = [0, 2000])
# corr_mean, mse_mean, mean_mean = plot_recon_metrics(X_test_ori, recon_channel_test, x_lim = [0, 2000])
Channel 1, corr = 0.6114, mse = 81.515949, mean = 29.5584. Channel 2, corr = 0.5870, mse = 63.216851, mean = 27.4149. Channel 3, corr = 0.5166, mse = 98.827603, mean = 32.5453. Channel 4, corr = 0.4702, mse = 43.889730, mean = 19.5454. Channel 5, corr = 0.4966, mse = 33.506293, mean = 13.0227. Channel 6, corr = 0.4695, mse = 190.091952, mean = 30.9896. Channel 7, corr = 0.7065, mse = 225.744413, mean = 50.5355. Channel 8, corr = 0.7108, mse = 208.645206, mean = 55.5833. Channel 9, corr = 0.5775, mse = 43.315605, mean = 20.8698. Channel 10, corr = 0.5762, mse = 109.909638, mean = 31.7638. Channel 11, corr = 0.6703, mse = 404.620769, mean = 46.6334. Channel 12, corr = 0.4810, mse = 89.241052, mean = 21.2069. Channel 13, corr = 0.7461, mse = 196.812028, mean = 52.4404. Channel 14, corr = 0.6554, mse = 190.569909, mean = 41.4571. Channel 15, corr = 0.6240, mse = 174.658730, mean = 38.8859.
# Dump the per-channel test-set metrics (correlation, MSE, mean) as plain lists.
for metric in (corr_mean, mse_mean, mean_mean):
    print(list(metric))
[0.6114096015015997, 0.5869850854222699, 0.5166355521135001, 0.47023718456402336, 0.4965712216454184, 0.4694611117575641, 0.7065237096137096, 0.7107523832229438, 0.5774546668196491, 0.5762299144335594, 0.6703044345980109, 0.4809608593894534, 0.746060185431153, 0.6554411887144711, 0.6239657718506972] [81.51594886623562, 63.21685149152076, 98.82760314330491, 43.88973006221668, 33.50629320109916, 190.09195200703334, 225.74441302958232, 208.64520648324037, 43.31560483996647, 109.90963842060329, 404.62076911667515, 89.24105215032186, 196.81202796868124, 190.56990884790193, 174.6587300240894] [29.558443014706768, 27.414924678304136, 32.54528668084857, 19.545427740051515, 13.022679916323572, 30.98961445901738, 50.535490176304165, 55.58334129172573, 20.86980557166033, 31.763762002375774, 46.63335753107305, 21.206949619877204, 52.44043212360176, 41.45713725737416, 38.88586835658078]
# Behavior name -> numeric label, stored as 1-element arrays to match the
# label arrays produced by load_data.
bhvs = {'crawling': np.array([0]),
'high picking treats': np.array([1]),
'low picking treats': np.array([2]),
'pg': np.array([3]),
'sitting still': np.array([4]),
'grooming': np.array([5]),
'no_behavior': np.array([-1])}
# Inverse lookup: numeric label -> behavior name. Use .item() to extract the
# scalar: calling int() on a 1-element ndarray is deprecated since NumPy 1.25
# and raises a TypeError in newer releases.
inv_bhvs = {v.item(): k for k, v in bhvs.items()}
# Wrap the test windows (unsupervised; labels held out for plotting only).
test_dataset = TensorDataset(torch.from_numpy(X_test))
# Encode test windows without pickling the latents.
z_run_test = vrae.transform(test_dataset, save = False)
# Stack train+test latents and labels for a joint 2-D visualization;
# one_in=4 presumably subsamples every 4th point — confirm in vrae.utils_EMG.
z_run_all = np.vstack((z_run, z_run_test))
y_all = np.vstack((y_train, y_test))
visualize(z_run = z_run_all, y = y_all, inv_bhvs = inv_bhvs, one_in = 4)